#from rmpoly import *
from rmpoly import RPoly, Poly, RPolyOverflowError

def S_poly(tp1,tp2):
  """S-polynomial of two monic polynomials given with their leading expv.

  tp1 = (expv1, p1), where expv1 == p1.leading_expv() and p1 is monic;
  tp2 = (expv2, p2) likewise.
  Return  LCM(LM(p1),LM(p2))/LM(p1)*p1 - LCM(LM(p1),LM(p2))/LM(p2)*p2.
  Raise RPolyOverflowError if bits_exp is too small for the result.
  """
  expv1, p1 = tp1
  expv2, p2 = tp2
  rp = p1.rp
  ordered = rp.ordertype
  weights = rp.ordertuple if ordered else None
  maskof = rp.maskof
  field_mask = rp.mask_exp
  width = rp._bits_exp
  n = rp.ngens

  # m1, m2: packed exponents of the monomials which multiply LM(p1),
  # LM(p2) up to their LCM; built one exponent field at a time
  m1 = 0
  m2 = 0
  wdeg = 0
  for k in range(n):
    shift = k*width
    sel = field_mask << shift
    a = expv1 & sel
    b = expv2 & sel
    top = max(a, b)
    m1 += top - a
    m2 += top - b
    if ordered:
      wdeg += (top >> shift)*weights[k]
  if ordered:
    # the field just above the variables stores the (weighted) degree:
    # swap the degrees of LM(p1), LM(p2) for the degree of the LCM
    dshift = n*width
    dsel = field_mask << dshift
    m1 += (wdeg << dshift) - (expv1 & dsel)
    m2 += (wdeg << dshift) - (expv2 & dsel)

  res = Poly(rp)
  res.iadd_m_mul_q(p1, (m1, 1))
  res.iadd_m_mul_q(p2, (m2, -1))

  if ordered:
    lead = res.leading_expv()
    if rp.order == 'grlex':
      # only the leading monomial is checked for 'grlex'
      # (presumably because it carries the largest degree)
      if lead and lead & maskof != 0:
        raise RPolyOverflowError('S_poly')
      return res
  # otherwise every monomial of the result is checked for overflow bits
  for expv in res:
    if expv and expv & maskof != 0:
      raise RPolyOverflowError('S_poly')
  return res


def groebner_basis(f, verbose=0):
  """Compute a Groebner basis of the ideal generated by f.

  Improved version of Buchberger's algorithm (GROEBNERNEWS2) as presented
  in T. Becker, V. Weispfenning 'Groebner bases' (1993) Springer, page 232;
  see also buchberger_improved in toy_buchberger.py in Sage.

  input: f, a sequence of polynomials of the same ring
  output: list of monic polynomials, each reduced (via mod1) modulo the
    others, sorted in decreasing order of their leading monomials;
    None if f is empty.
  If verbose is true, print the number of critical pairs whose
  S-polynomial reduced to zero.
  """

  def select(P):
    # select the critical pair with minimum LCM of the leading monomials
    return min(P, key=lambda pair: lcm_expv(f[pair[0]][0], f[pair[1]][0]))

  def normal(g, H):
    """
    compute the rest h of the division of g wrt the polynomials f[i], i in H;
    if the rest is zero return None;
    else, if h is not yet in f, append (expv, monic h) to f;
    return (expv, index in f) of h
    """
    h = g.mod1([f[i] for i in H])
    if h == zero:
      return None
    hk = tuple(h.keys())
    if hk not in fd:
      fd[hk] = len(f)
      hexpv = h.leading_expv()
      f.append((hexpv, h/h[hexpv]))
      return hexpv, fd[hk]
    return f[fd[hk]][0], fd[hk]

  def update(G, CP, h):
    """update G using the set of critical pairs CP and h, the index in f
    of the new polynomial; see [BW] page 230.
    G and CP are consumed (emptied); return the new (G, CP).
    """
    hexpv = f[h][0]
    # filter new pairs (h,g), g in G
    C = G.copy()
    D = set()

    while C:
      # select a pair (h,g) by popping an element from C
      g = C.pop()
      gexpv = f[g][0]
      LCMhg = lcm_expv(hexpv, gexpv)

      def lcm_divides(p):
        # LCM(LM(h), LM(p)) divides LCM(LM(h),LM(g)); the test on packed
        # exponents presumably relies on maskof flagging a negative field
        expv = lcm_expv(hexpv, f[p][0])
        return (LCMhg - expv)&maskof == 0

      # HT(h) and HT(g) disjoint: hexpv + gexpv == LCMhg
      if hexpv + gexpv == LCMhg or (
          not any(lcm_divides(c) for c in C) and
          not any(lcm_divides(pr[1]) for pr in D)):
        D.add((h, g))

    # drop the pairs with disjoint leading terms (Buchberger criterion)
    E = set()
    while D:
      _, g = D.pop()
      gexpv = f[g][0]
      if hexpv + gexpv != lcm_expv(hexpv, gexpv):
        E.add((h, g))

    # filter old pairs
    B_new = set()
    while CP:
      g1, g2 = CP.pop()
      g1expv = f[g1][0]
      g2expv = f[g2][0]
      LCM12 = lcm_expv(g1expv, g2expv)
      # keep if HT(h) does not divide lcm(HT(g1),HT(g2)), or if the pair
      # (g1,g2) would not be covered by the pairs with h
      if not (LCM12 - hexpv)&maskof == 0 or \
          lcm_expv(g1expv, hexpv) == LCM12 or \
          lcm_expv(g2expv, hexpv) == LCM12:
        B_new.add((g1, g2))

    B_new |= E

    # filter polynomials: drop those whose HT is divisible by HT(h)
    G_new = set()
    while G:
      g = G.pop()
      if not (f[g][0] - hexpv)&maskof == 0:
        G_new.add(g)
    G_new.add(h)

    return G_new, B_new
  # end of update ################################

  if not f:
    return None
  rp = f[0].rp
  zero = Poly(rp)
  maskof = rp.maskof

  # lcm_expv(expv1,expv2) computes the expv for the lcm
  # of the monomials with expv1,expv2; the results are cached
  lcm_expv0 = rp.lcm_expv
  d_lcm_expv = {}
  def lcm_expv(expv1, expv2):
    key = (expv1, expv2)
    if key not in d_lcm_expv:
      d_lcm_expv[key] = lcm_expv0(expv1, expv2)
    return d_lcm_expv[key]

  # f is replaced with a list of (p.leading_expv(),p), where p is monic
  # and all polynomials have different sets of monomials.
  # In this way, p is identified by pk = tuple(p.keys());
  # p is not hashable, so one cannot use a built-in set of (expv,p).
  # The set of polynomials is implemented with the dictionary fd:
  # ip = fd[pk] is the index of p in f, and expv,p = f[ip]

  # reduce the list of initial polynomials; see [BW] page 203
  f1 = list(f)
  while 1:
    f = f1[:]
    f1 = []
    for i in range(len(f)):
      p = f[i]
      _, r = p.division(f[:i])
      if r != 0:
        f1.append(r)
    # when f does not change anymore, no two elements have the same
    # LT, so the elements of f have pairwise different sets of monomials
    if f == f1:
      break

  # convert f into a list of pairs (expv,p) where expv is the encoded
  # tuple of exponents of the LT of p and p is a monic polynomial
  f1 = []
  for h in f:
    if h:
      expv = h.leading_expv()
      f1.append((expv, h/h[expv]))
  f = f1

  # order the initial polynomials by decreasing leading monomial
  f.sort(reverse=True)

  fd = {}    # ip = fd[tuple(p.keys())]; (expv,p) = f[ip]
  F = set()  # set of indices of polynomials
  G = set()  # set of indices of intermediate would-be Groebner basis
  CP = set() # set of pairs of indices of critical pairs
  for i, h in enumerate(f):
    fd[tuple(h[1].keys())] = i
    F.add(i)

  #####################################
  #  algorithm GROEBNERNEWS2 in [BW] page 232
  while F:
    # select the polynomial with minimum leading monomial
    # (the initial leading monomials are pairwise distinct)
    h = min(F, key=lambda i: f[i][0])
    F.remove(h)
    G, CP = update(G, CP, h)

  # count the number of critical pairs which reduce to zero
  reductions_to_zero = 0

  while CP:
    g1, g2 = select(CP)
    CP.remove((g1, g2))
    s = S_poly(f[g1], f[g2])
    # normal() appends the remainder to f if it is new
    h = normal(s, sorted(G, key=lambda g: f[g][0]))
    if h:
      G, CP = update(G, CP, h[1])
    else:
      reductions_to_zero += 1
  ######################################
  # now G is a Groebner basis; reduce it
  Gr = set()
  for g in G:
    h = normal(f[g][1], G - set([g]))
    if h:
      Gr.add(h[1])
  # replace ip with (expv,p)
  Gr = [f[g] for g in Gr]

  # order according to the monomial ordering
  Gr.sort(reverse=True)

  # replace (expv,p) with p
  Gr = [x[1] for x in Gr]
  if verbose:
    # single-argument print: identical output on Python 2 and 3
    print('reductions_to_zero= %s' % reductions_to_zero)
  return Gr

def lcm(p1,p2):
  """least common multiple of two polynomials
  see [CLO] D. Cox, J. Little, D.O'Shea 
    'Ideals, varieties and algorithms, 3rd Ed., p188-189'
  >>> from rmpoly import *
  >>> from fractions import Fraction as mpq
  >>> rp,y,x = rgens('y,x',6,mpq)
  >>> p1 = (x-1)**2*(x-2)*(x-3)**2
  >>> p2 = (x-1)*(x-2)**2*(x-4)
  >>> p3 = lcm(p1,p2)
  >>> assert p3 == p1*(x-2)*(x-4)
  >>> p1 = (x-y)**2*(y-1)
  >>> p2 = (x-y)*(y-2)
  >>> p3 = lcm(p1,p2)
  >>> assert p3 == p1*(y-2)
  """
  rp = p1.rp
  # lcm(0, p) = 0; the elimination below would otherwise divide by,
  # or take a basis element of, a degenerate ideal
  if not p1 or not p2:
    return Poly(rp)
  # [CLO]: with I = <t*p1, (1-t)*p2>, the ideal I & k[x] is generated by
  # lcm(p1,p2). Work in a ring with the extra generator '__t' (and extra
  # exponent bits, presumably to make room for the elimination).
  rp1 = RPoly(rp.pol_gens+('__t',), rp.bits_exp+10, rp.field)
  t = rp1.gens()[-1]
  p1a = rp1.new(p1)
  p2a = rp1.new(p2)
  gr = groebner_basis([t*p1a, (1-t)*p2a])
  # groebner_basis sorts by decreasing leading monomial; the last element
  # is taken as the generator free of '__t' (this presumably relies on
  # the monomial order eliminating '__t')
  return rp.new(gr[-1])

def gcd(p1,p2):
  """greatest common divisor of two polynomials
  >>> from rmpoly import *
  >>> from fractions import Fraction as mpq
  >>> rp,y,x = rgens('y,x',6,mpq)
  >>> p1 = (x-1)**2*(x-2)*(x-3)**2
  >>> p2 = (x-1)*(x-2)**2*(x-4)
  >>> p3 = gcd(p1,p2)
  >>> assert p3 == (x-1)*(x-2)
  >>> p1 = (x-y)**2*(y-1)
  >>> p2 = (x-y)*(y-2)
  >>> gcd(p1,p2)
   +x -y
  """
  # gcd(p, 0) = p. Handle zero input directly, since p1*p2 = gcd*lcm
  # cannot be solved for gcd when the lcm is zero.
  # Note: the argument object itself is returned in this case.
  if not p2:
    return p1
  if not p1:
    return p2
  #TODO faster case in which p1, p2 are univariate polynomials
  # gcd = p1*p2 / lcm(p1,p2); the division must be exact.
  # The result is not normalized to be monic here.
  q, r = (p1*p2).division([lcm(p1, p2)])
  assert not r
  return q[0]



if __name__ == "__main__":
  # run the doctests of this module
  import doctest
  import sys
  if sys.version_info < (2, 6):
    # single-argument print: same output on Python 2 and 3
    print('doctests require Fraction, available from Python2.6')
    sys.exit()
  doctest.testmod()

